/**
 * \file dnn/src/rocm/elemwise_helper.h.hip
 *
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "hip_header.h"
#include "src/rocm/utils.h.hip"
#include "src/common/elemwise_helper.cuh"
#include "src/rocm/int_fastdiv.h.hip"

/*
 * please note that all arithmetic on the GPU is 32-bit for best performance;
 * this limits the max possible size
 */

namespace megdnn {
namespace rocm {

//! internals for element-wise
namespace elemwise_intl {
#define devfunc __device__ __forceinline__

/*!
 * \brief get hip launch specs for element-wise kernel
 * \param kern kernel function address
 * \param size total size of elements
 * \param[out] grid_size receives the grid size to launch \p kern with
 * \param[out] block_size receives the block size to launch \p kern with
 */
void get_launch_spec(const void* kern, size_t size, int* grid_size,
                     int* block_size);

//! report a tensor ndim that no dispatch case handles; declared
//! MEGDNN_NORETURN, so callers do not expect control to come back
MEGDNN_NORETURN void on_bad_ndim(int ndim);

/*!
 * \brief broadcast pattern of a tensor, described by which strides are zero
 *
 * The digits in BCAST_x[0]x[1]... follow x[i] == !stride[i], i.e. a 1 marks
 * an axis whose stride is zero.
 */
enum BcastType {
    BCAST_OTHER,  //!< any layout without a dedicated pattern below
    BCAST_101,    //!< 3 dims; stride is zero on axes 0 and 2 only
    BCAST_10,     //!< 2 dims; stride is zero on axis 0 only
    BCAST_01,     //!< 2 dims; stride is zero on axis 1 only
    BCAST_FULL    //!< 1 dim; stride is zero
};

/*!
 * \brief visitor to access an element in a tensor at given logical index
 * \tparam ndim number of tensor dimensions handled by the visitor
 * \tparam ctype plain element ctype (i.e. ctype in DTypeTrait)
 * \tparam brd_type broadcast pattern of the tensor (see BcastType); selects
 *      the specialization used to map a logical index to a physical offset
 *
 * host interface:
 *      void host_init(
 *              const TensorND &tensor, int grid_size, int block_size)
 *
 * device interface:
 *      void thread_init(uint32_t idx)
 *          called on thread entrance, with logical indexing; the index may
 *          go beyond buffer range
 *
 *      ctype* ptr()
 *          return buffer pointer; can be used by specialized OpCaller
 *
 *      void next()
 *          called before moving to next chunk on each thread
 *
 *      int offset(uint32_t idx)
 *          get physical offset from logical index
 *
 *      ctype& at(uint32_t idx)
 *          ptr()[offset(idx)]
 *
 */
template <int ndim, typename ctype, BcastType brd_type>
class ParamElemVisitor;

//! device accessors shared by the ParamElemVisitor specializations: ptr()
//! returns the raw buffer pointer and at() dereferences it at offset(idx);
//! the class expanding this macro must provide ctype, m_ptr and offset()
#define PARAM_ELEM_VISITOR_COMMON_DEV      \
    devfunc ctype* ptr() { return m_ptr; } \
    devfunc ctype& at(uint32_t idx) { return m_ptr[offset(idx)]; }

/*!
 * \brief specialization for BCAST_OTHER (generic strided layout)
 *
 * The logical index is peeled axis by axis, from the last axis down to axis
 * 1, and each per-axis remainder is weighted by the stride of that axis.
 */
template <int ndim, typename ctype>
class ParamElemVisitor<ndim, ctype, BCAST_OTHER> {
    ctype* __restrict m_ptr;
    int m_stride[ndim];

    //! m_shape_highdim[i] = original_shape[i + 1]
#ifdef _MSC_VER
    Uint32Fastdiv m_shape_highdim[ndim > 1 ? ndim - 1 : 1];
#else
    Uint32Fastdiv m_shape_highdim[ndim - 1];
#endif

public:
    static const int NDIM = ndim;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! nothing to prepare per thread
    devfunc void thread_init(uint32_t) {}

    //! no per-chunk state to advance
    devfunc void next() {}

    //! convert logical index \p idx into a physical element offset
    devfunc int offset(uint32_t idx) {
        int phys = 0;
#pragma unroll
        for (int axis = ndim - 1; axis >= 1; --axis) {
            Uint32Fastdiv& extent = m_shape_highdim[axis - 1];
            uint32_t quot = idx / extent;
            // remainder of idx with respect to this axis
            uint32_t rem = idx - quot * extent.divisor();
            phys += rem * m_stride[axis];
            idx = quot;
        }
        // what is left after peeling axes ndim-1 .. 1 is weighted by stride 0
        phys += idx * m_stride[0];
        return phys;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

/*!
 * \brief specialization for ndim == 3 and BCAST_101
 * (for dimshuffle 'x', 0, 'x')
 *
 * visit: idx / m_shape2 % m_shape1
 */
template <typename ctype>
class ParamElemVisitor<3, ctype, BCAST_101> {
    ctype* __restrict m_ptr;
    StridedDivSeq2 m_shape12;
    int m_stride1;

public:
    static const int NDIM = 3;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! seed the div sequence with the first logical index of this thread
    devfunc void thread_init(uint32_t idx) {
        m_shape12.device_init(idx);
    }

    //! advance the div sequence to the next chunk
    devfunc void next() {
        m_shape12.next();
    }

    //! offset comes from the div sequence state, not from the argument
    devfunc int offset(uint32_t /* idx */) {
        return m_shape12.get() * m_stride1;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

/*!
 * \brief specialization for ndim == 2 and BCAST_10
 *
 * visit: idx % m_shape1
 */
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_10> {
    ctype* __restrict m_ptr;
    StridedDivSeq<false> m_shape1;
    int m_stride1;

public:
    static const int NDIM = 2;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! seed the div sequence with the first logical index of this thread
    devfunc void thread_init(uint32_t idx) {
        m_shape1.device_init(idx);
    }

    //! advance the div sequence to the next chunk
    devfunc void next() {
        m_shape1.next();
    }

    //! offset is r() of the div sequence scaled by the axis-1 stride
    devfunc int offset(uint32_t /* idx */) {
        return m_shape1.r() * m_stride1;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

/*!
 * \brief specialization for ndim == 2 and BCAST_01
 *
 * visit: idx / shape1
 */
template <typename ctype>
class ParamElemVisitor<2, ctype, BCAST_01> {
    ctype* __restrict m_ptr;
    StridedDivSeq<true> m_shape1;
    int m_stride0;

public:
    static const int NDIM = 2;

    void host_init(const TensorND& rv, int grid_size, int block_size);

    // device-side interface; guarded like the other specializations so that
    // translation units built without MEGDNN_CC_CUDA only see host_init()
#if MEGDNN_CC_CUDA
    //! seed the div sequence with the first logical index of this thread
    devfunc void thread_init(uint32_t idx) { m_shape1.device_init(idx); }

    //! advance the div sequence to the next chunk
    devfunc void next() { m_shape1.next(); }

    //! offset is q() of the div sequence scaled by the axis-0 stride
    devfunc int offset(uint32_t /* idx */) { return m_shape1.q() * m_stride0; }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

/*!
 * \brief specialization for ndim == 1 and BCAST_FULL
 *
 * Every logical index maps to physical offset 0.
 */
template <typename ctype>
class ParamElemVisitor<1, ctype, BCAST_FULL> {
    ctype* __restrict m_ptr;

public:
    static const int NDIM = 1;

    void host_init(const TensorND& rv, int grid_size, int block_size);

#if MEGDNN_CC_CUDA
    //! nothing to prepare per thread
    devfunc void thread_init(uint32_t) {}

    //! no per-chunk state to advance
    devfunc void next() {}

    //! the logical index is ignored; the offset is always 0
    devfunc int offset(uint32_t idx) {
        MEGDNN_MARK_USED_VAR(idx);
        return 0;
    }

    PARAM_ELEM_VISITOR_COMMON_DEV
#endif
};

#undef PARAM_ELEM_VISITOR_COMMON_DEV

#if MEGDNN_CC_CUDA
/*
 * OpCaller is used to invoke user operator with loaded element arguments.
 *
 * device interface:
 *      void thread_init(uint32_t idx);
 *
 *      void on(uint32_t idx);
 *
 *      void next();
 */

/*!
 * \brief caller for operators that take no tensor params (i.e. arity == 0);
 *      the operator only receives the logical index
 */
template <class Op>
struct OpCallerNull {
    Op op;

    //! no visitor state to set up
    devfunc void thread_init(uint32_t) {}

    //! invoke the operator for logical index \p idx
    devfunc void on(uint32_t idx) {
        op(idx);
    }

    //! no visitor state to advance
    devfunc void next() {}
};

/*!
 * \brief call an operator whose params are all promoted to the same ndim and
 *      broadcast type, so a single visitor class serves every param
 * \tparam Op user operator type
 * \tparam arity number of params of the operator
 * \tparam PVis ParamElemVisitor class
 */
template <class Op, int arity, class PVis>
struct OpCallerUniform;

//! specialization for arity == 1: one visitor feeds the single operand
template <class Op, class PVis>
struct OpCallerUniform<Op, 1, PVis> {
    Op op;
    PVis par[1];

    //! forward thread entrance to the visitor
    devfunc void thread_init(uint32_t idx) {
        par[0].thread_init(idx);
    }

    //! call the operator on the element selected by \p idx
    devfunc void on(uint32_t idx) {
        op(idx, par[0].at(idx));
    }

    //! move the visitor to the next chunk
    devfunc void next() {
        par[0].next();
    }
};
//! specialization for arity == 2: two visitors of the same class
template <class Op, class PVis>
struct OpCallerUniform<Op, 2, PVis> {
    Op op;
    PVis par[2];

    //! forward thread entrance to both visitors, in param order
    devfunc void thread_init(uint32_t idx) {
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            par[i].thread_init(idx);
        }
    }

    //! call the operator on the elements selected by \p idx
    devfunc void on(uint32_t idx) {
        op(idx, par[0].at(idx), par[1].at(idx));
    }

    //! move both visitors to the next chunk, in param order
    devfunc void next() {
#pragma unroll
        for (int i = 0; i < 2; ++i) {
            par[i].next();
        }
    }
};
//! specialization for arity == 3: three visitors of the same class
template <class Op, class PVis>
struct OpCallerUniform<Op, 3, PVis> {
    Op op;
    PVis par[3];

    //! forward thread entrance to all visitors, in param order
    devfunc void thread_init(uint32_t idx) {
#pragma unroll
        for (int i = 0; i < 3; ++i) {
            par[i].thread_init(idx);
        }
    }

    //! call the operator on the elements selected by \p idx
    devfunc void on(uint32_t idx) {
        op(idx, par[0].at(idx), par[1].at(idx), par[2].at(idx));
    }

    //! move all visitors to the next chunk, in param order
    devfunc void next() {
#pragma unroll
        for (int i = 0; i < 3; ++i) {
            par[i].next();
        }
    }
};

/*!
 * \brief call binary (i.e. arity == 2) operator where each param has its own
 *      visitor class
 * \tparam PVis0 visitor class of the first param
 * \tparam PVis1 visitor class of the second param
 */
template <class Op, class PVis0, class PVis1>
struct OpCallerBinary {
    Op op;
    PVis0 par0;
    PVis1 par1;

    //! forward thread entrance to both visitors
    devfunc void thread_init(uint32_t idx) {
        par0.thread_init(idx);
        par1.thread_init(idx);
    }

    //! call the operator on the elements selected by \p idx
    devfunc void on(uint32_t idx) {
        op(idx, par0.at(idx), par1.at(idx));
    }

    //! move both visitors to the next chunk
    devfunc void next() {
        par0.next();
        par1.next();
    }
};

/*!
 * \brief generic element-wise kernel driving \p op_caller over logical indices
 *
 * Every thread starts at its global thread index and advances by the total
 * number of launched threads; each thread works on at most 3 elements, see
 * get_launch_spec.
 *
 * \param op_caller OpCaller object, passed by value
 * \param size number of elements to process
 */
template <class OpCaller>
__global__ void cuda_kern(OpCaller op_caller, uint32_t size) {
    const uint32_t step = hipBlockDim_x * hipGridDim_x;
    uint32_t idx = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;

    // thread_init runs unconditionally, even when idx is already out of range
    op_caller.thread_init(idx);

    // first element
    if (idx >= size)
        return;
    op_caller.on(idx);

    // second element
    idx += step;
    if (idx >= size)
        return;
    op_caller.next();
    op_caller.on(idx);

    // third and last element
    idx += step;
    if (idx >= size)
        return;
    op_caller.next();
    op_caller.on(idx);
}

/*!
 * \brief invoke a user Op passed to run_elemwise
 *
 * The work is done in the constructor, which dispatches and launches the
 * kernel.
 * \tparam Op user operator type
 * \tparam ctype element type of the params
 * \tparam arity number of params of the operator
 */
template <class Op, typename ctype, int arity>
class UserOpInvoker;

//! run op by promoting all params to same ndim
template <class Op, typename ctype, int arity>
class UserOpInvokerToSameNdim {
    const ElemwiseOpParamN<arity>& m_param;
    hipStream_t m_stream;
    const Op& m_op;

    void dispatch0() {
        switch (m_param.max_ndim) {
#define cb(ndim) \
    case ndim:   \
        return dispatch1<ndim>();
            MEGDNN_FOREACH_TENSOR_NDIM(cb)
#undef cb
        }
        on_bad_ndim(m_param.max_ndim);
    }

    template <int ndim>
    void dispatch1() {
        typedef OpCallerUniform<Op, arity,
                                ParamElemVisitor<ndim, ctype, BCAST_OTHER>>
                Caller;
        size_t size = m_param.size;
        int grid_size, block_size;
        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
                        &block_size);

        Caller caller;
        caller.op = m_op;
        for (int i = 0; i < arity; ++i)
            caller.par[i].host_init(m_param[i], grid_size, block_size);

        hipLaunchKernelGGL(fptr,
                           dim3(grid_size), dim3(block_size), 0, m_stream,
                           caller, size);
        after_kernel_launch();
    }

public:
    UserOpInvokerToSameNdim(const ElemwiseOpParamN<arity>& param,
                            hipStream_t stream, const Op& op)
            : m_param(param), m_stream(stream), m_op(op) {
        dispatch0();
    }
};

//! general case: delegate everything to UserOpInvokerToSameNdim
template <class Op, typename ctype, int arity>
class UserOpInvoker : public UserOpInvokerToSameNdim<Op, ctype, arity> {
    typedef UserOpInvokerToSameNdim<Op, ctype, arity> Base;

public:
    //! the base class constructor performs the dispatch and launch
    UserOpInvoker(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
                  const Op& op)
            : Base(param, stream, op) {}
};

//! specialization for arity == 0
template <class Op, typename ctype>
class UserOpInvoker<Op, ctype, 0> {
public:
    UserOpInvoker(const ElemwiseOpParamN<0>& param, hipStream_t stream,
                  const Op& op) {
        size_t size = param.size;
        typedef OpCallerNull<Op> Caller;
        Caller caller;
        caller.op = op;
        int grid_size, block_size;
        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
                        &block_size);
        hipLaunchKernelGGL(fptr,
                           dim3(grid_size), dim3(block_size), 0, stream, caller,
                           size);
        after_kernel_launch();
    }
};

/*!
 * \brief define three functions (for ndim 1, 2 and 3) that inspect a stride
 *      array and dispatch on the broadcast type it matches
 *
 * \param _cb_header macro producing the function header for a given ndim
 * \param _cb_dispatch macro invoked as _cb_dispatch(ndim, bcast_type) to
 *      perform the actual dispatch
 * \param _stride expression yielding the const ptrdiff_t* stride array
 *
 * Selection rules, as implemented below:
 *  - ndim 1: BCAST_FULL if stride[0] is zero, otherwise BCAST_OTHER
 *  - ndim 2: BCAST_10 if only stride[0] is zero, BCAST_01 if only stride[1]
 *    is zero, otherwise BCAST_OTHER
 *  - ndim 3: BCAST_101 if stride[0] and stride[2] are zero and stride[1] is
 *    not, otherwise BCAST_OTHER
 */
#define DEFINE_BRDCAST_DISPATCH_RECEIVERS(_cb_header, _cb_dispatch, _stride) \
    _cb_header(1) {                                                          \
        const ptrdiff_t* stride = _stride;                                   \
        if (!stride[0]) {                                                    \
            return _cb_dispatch(1, BCAST_FULL);                              \
        }                                                                    \
        _cb_dispatch(1, BCAST_OTHER);                                        \
    }                                                                        \
    _cb_header(2) {                                                          \
        const ptrdiff_t* stride = _stride;                                   \
        if (!stride[0] && stride[1]) {                                       \
            return _cb_dispatch(2, BCAST_10);                                \
        }                                                                    \
        if (stride[0] && !stride[1]) {                                       \
            return _cb_dispatch(2, BCAST_01);                                \
        }                                                                    \
        _cb_dispatch(2, BCAST_OTHER);                                        \
    }                                                                        \
    _cb_header(3) {                                                          \
        const ptrdiff_t* stride = _stride;                                   \
        if (!stride[0] && stride[1] && !stride[2]) {                         \
            return _cb_dispatch(3, BCAST_101);                               \
        }                                                                    \
        _cb_dispatch(3, BCAST_OTHER);                                        \
    }

/*!
 * \brief specialization for binary opr
 *
 * A ParamElemVisitor specialization is chosen for each of the two params
 * independently, from its ndim and zero strides, and the kernel is launched
 * through OpCallerBinary. If the ndim of either param is not matched by a
 * case generated from MEGDNN_FOREACH_TENSOR_NDIM_SMALL, the work is handed
 * to UserOpInvokerToSameNdim instead.
 */
template <class Op, typename ctype>
class UserOpInvoker<Op, ctype, 2> {
    //! set once a kernel has been launched; asserted to catch a missing or
    //! a double dispatch
    bool m_invoked;
    const ElemwiseOpParamN<2>& m_param;
    hipStream_t m_stream;
    const Op& m_op;

    //! generic path: run via UserOpInvokerToSameNdim (its constructor
    //! performs the launch)
    void fallback() {
        megdnn_assert(!m_invoked);
        UserOpInvokerToSameNdim<Op, ctype, 2>(m_param, m_stream, m_op);
        m_invoked = true;
    }

    //! stage 0: branch on the ndim of param 0
    void dispatch0() {
        switch (m_param[0].layout.ndim) {
#define cb(ndim) \
    case ndim:   \
        return dispatch1_##ndim();
            MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb)
#undef cb
        }
        fallback();
    }

    // stage 1: dispatch1_<ndim>() picks the visitor of param 0 from its
    // strides and forwards it to dispatch2<PVis0>()
#define cb_header(ndim) void dispatch1_##ndim()
#define cb_dispatch(ndim, brdcast_mask) \
    dispatch2<ParamElemVisitor<ndim, ctype, brdcast_mask>>()
    DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch,
                                      m_param[0].layout.stride)
#undef cb_header
#undef cb_dispatch

    //! stage 2: branch on the ndim of param 1, keeping the visitor chosen
    //! for param 0
    template <class PVis0>
    void dispatch2() {
        switch (m_param[1].layout.ndim) {
#define cb(ndim) \
    case ndim:   \
        return dispatch3_##ndim<PVis0>();
            MEGDNN_FOREACH_TENSOR_NDIM_SMALL(cb)
#undef cb
        }
        fallback();
    }

    // stage 3: dispatch3_<ndim>() picks the visitor of param 1 from its
    // strides and calls do_run<PVis0, PVis1>()
#define cb_header(ndim)    \
    template <class PVis0> \
    void dispatch3_##ndim()
#define cb_dispatch(ndim, brdcast_mask) \
    do_run<PVis0, ParamElemVisitor<ndim, ctype, brdcast_mask>>()
    DEFINE_BRDCAST_DISPATCH_RECEIVERS(cb_header, cb_dispatch,
                                      m_param[1].layout.stride)
#undef cb_header
#undef cb_dispatch

    //! launch the kernel with the two selected visitor classes
    template <class PVis0, class PVis1>
    void do_run() {
        megdnn_assert(!m_invoked);
        m_invoked = true;
        typedef OpCallerBinary<Op, PVis0, PVis1> Caller;
        int grid_size, block_size;
        void (*fptr)(Caller, uint32_t) = cuda_kern<Caller>;
        size_t size = m_param.size;
        get_launch_spec(reinterpret_cast<const void*>(fptr), size, &grid_size,
                        &block_size);
        Caller caller;
        caller.op = m_op;
        // the visitors are initialized on the host with the launch config
        caller.par0.host_init(m_param[0], grid_size, block_size);
        caller.par1.host_init(m_param[1], grid_size, block_size);
        hipLaunchKernelGGL(fptr,
                           dim3(grid_size), dim3(block_size), 0, m_stream,
                           caller, size);
        after_kernel_launch();
    }

public:
    //! dispatches and launches immediately; asserts that exactly one path
    //! (do_run or fallback) has run
    UserOpInvoker(const ElemwiseOpParamN<2>& param, hipStream_t stream,
                  const Op& op)
            : m_param(param), m_stream(stream), m_op(op) {
        m_invoked = false;
        dispatch0();
        megdnn_assert(m_invoked);
    }
};

#undef DEFINE_BRDCAST_DISPATCH_RECEIVERS

#endif  // MEGDNN_CC_CUDA

#undef devfunc
}  // namespace elemwise_intl

/*!
 * \brief general element-wise kernel launcher
 *
 * \tparam Op operator type; default-constructed when \p op is omitted
 * \tparam ctype element type of the params
 * \tparam arity number of params for the operator
 * \param param param values for the operator; must have been initialized (i.e.
 *      by calling ElemwiseOpParamN::init_from_given_tensor). The params
 *      can have arbitrary layouts, as long as they share the same total number
 *      of elements.
 * \param stream hip stream on which the kernel is launched
 * \param op callable with a signature compatible with
 *      `void op(uint32_t idx, ctype& param0, ..., ctype& param[arity - 1])`
 *      if arity == 0, there is only an `idx` input
 */
template <class Op, typename ctype, int arity>
void run_elemwise(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
                  const Op& op = Op());

#if MEGDNN_CC_CUDA
template <class Op, typename ctype, int arity>
void run_elemwise(const ElemwiseOpParamN<arity>& param, hipStream_t stream,
                  const Op& op) {
    param.assert_initialized();
    elemwise_intl::UserOpInvoker<Op, ctype, arity>(param, stream, op);
}

/*!
 * \brief explicit instantiation of run_elemwise for given template params;
 *      used in .cu files, so corresponding run_elemwise can be called from .cpp
 *
 * The expansion has no trailing semicolon; the user of the macro supplies it.
 */
#define INST_RUN_ELEMWISE(Op, ctype, arity)       \
    template void run_elemwise<Op, ctype, arity>( \
            const ElemwiseOpParamN<arity>&, hipStream_t, const Op&)
#endif  // MEGDNN_CC_CUDA

}  // namespace rocm
}  // namespace megdnn

// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

